{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "264e72d4-495a-48be-b95d-c8d18f13fe66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Select the visible GPUs before transformers is imported, so the setting is\n",
    "# already in place when the CUDA devices are first enumerated.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2\"\n",
    "\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, set_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4311b4ed-30f7-455d-9929-c7f32f1fc100",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/file/ljw22/miniconda3/envs/cu121/lib/python3.10/site-packages/auto_gptq/nn_modules/triton_utils/kernels.py:411: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n",
      "  def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):\n",
      "/file/ljw22/miniconda3/envs/cu121/lib/python3.10/site-packages/auto_gptq/nn_modules/triton_utils/kernels.py:419: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.\n",
      "  def backward(ctx, grad_output):\n",
      "/file/ljw22/miniconda3/envs/cu121/lib/python3.10/site-packages/auto_gptq/nn_modules/triton_utils/kernels.py:461: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n",
      "  @custom_fwd(cast_inputs=torch.float16)\n",
      "CUDA extension not installed.\n",
      "CUDA extension not installed.\n",
      "/file/ljw22/miniconda3/envs/cu121/lib/python3.10/site-packages/transformers/modeling_utils.py:5006: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead\n",
      "  warnings.warn(\n",
      "`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c8e251f1599f44049d59e96342b01b04",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/11 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the checkpoint and its tokenizer through the stock transformers classes;\n",
    "# passing `trust_remote_code=True` is no longer needed.\n",
    "#\n",
    "# The checkpoint directory can be overridden with the QWEN_MODEL_PATH environment\n",
    "# variable (for example to point at a Qwen2.5-14B-Instruct directory) instead of\n",
    "# editing this cell. The default is the GPTQ-Int4 72B location used so far.\n",
    "qwen_model_path = os.environ.get(\n",
    "    \"QWEN_MODEL_PATH\",\n",
    "    \"/file/ljw22/Qwen2.5-72B-Instruct-GPTQ-Int4\",\n",
    ")\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    qwen_model_path,\n",
    "    torch_dtype=\"auto\",  # let transformers pick the dtype recorded with the checkpoint\n",
    "    device_map=\"auto\",  # let transformers place the weights on the visible devices\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(qwen_model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "73bc76f2-f0bc-4a1b-ae5b-49605422c3bd",
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[4], line 17\u001b[0m\n\u001b[1;32m     13\u001b[0m model_inputs \u001b[38;5;241m=\u001b[39m tokenizer([text], return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m)\u001b[38;5;241m.\u001b[39mto(model\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m     15\u001b[0m \u001b[38;5;66;03m# Directly use generate() and tokenizer.decode() to get the output.\u001b[39;00m\n\u001b[1;32m     16\u001b[0m \u001b[38;5;66;03m# Use `max_new_tokens` to control the maximum output length.\u001b[39;00m\n\u001b[0;32m---> 17\u001b[0m generated_ids \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     18\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     19\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmax_new_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m512\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m     20\u001b[0m \u001b[43m)\u001b[49m\n\u001b[1;32m     21\u001b[0m generated_ids \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m     22\u001b[0m     output_ids[\u001b[38;5;28mlen\u001b[39m(input_ids):] \u001b[38;5;28;01mfor\u001b[39;00m input_ids, output_ids \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(model_inputs\u001b[38;5;241m.\u001b[39minput_ids, generated_ids)\n\u001b[1;32m     23\u001b[0m ]\n\u001b[1;32m     25\u001b[0m response \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mbatch_decode(generated_ids, skip_special_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)[\u001b[38;5;241m0\u001b[39m]\n",
      "File \u001b[0;32m/file/ljw22/miniconda3/envs/cu121/lib/python3.10/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    115\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/file/ljw22/miniconda3/envs/cu121/lib/python3.10/site-packages/transformers/generation/utils.py:2215\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m   2207\u001b[0m     input_ids, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_expand_inputs_for_generation(\n\u001b[1;32m   2208\u001b[0m         input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[1;32m   2209\u001b[0m         expand_size\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_return_sequences,\n\u001b[1;32m   2210\u001b[0m         is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[1;32m   2211\u001b[0m         \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs,\n\u001b[1;32m   2212\u001b[0m     )\n\u001b[1;32m   2214\u001b[0m     \u001b[38;5;66;03m# 12. 
run sample (it degenerates to greedy search when `generation_config.do_sample=False`)\u001b[39;00m\n\u001b[0;32m-> 2215\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2216\u001b[0m \u001b[43m        \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2217\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlogits_processor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_logits_processor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2218\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_stopping_criteria\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2219\u001b[0m \u001b[43m        \u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2220\u001b[0m \u001b[43m        \u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msynced_gpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2221\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstreamer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstreamer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2222\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2223\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2225\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m generation_mode \u001b[38;5;129;01min\u001b[39;00m (GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SAMPLE, GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SEARCH):\n\u001b[1;32m   2226\u001b[0m     \u001b[38;5;66;03m# 11. 
prepare beam search scorer\u001b[39;00m\n\u001b[1;32m   2227\u001b[0m     beam_scorer \u001b[38;5;241m=\u001b[39m BeamSearchScorer(\n\u001b[1;32m   2228\u001b[0m         batch_size\u001b[38;5;241m=\u001b[39mbatch_size,\n\u001b[1;32m   2229\u001b[0m         num_beams\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_beams,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   2234\u001b[0m         max_length\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mmax_length,\n\u001b[1;32m   2235\u001b[0m     )\n",
      "File \u001b[0;32m/file/ljw22/miniconda3/envs/cu121/lib/python3.10/site-packages/transformers/generation/utils.py:3249\u001b[0m, in \u001b[0;36mGenerationMixin._sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)\u001b[0m\n\u001b[1;32m   3247\u001b[0m     probs \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39msoftmax(next_token_scores, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m   3248\u001b[0m     \u001b[38;5;66;03m# TODO (joao): this OP throws \"skipping cudagraphs due to ['incompatible ops']\", find solution\u001b[39;00m\n\u001b[0;32m-> 3249\u001b[0m     next_tokens \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmultinomial\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprobs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m   3250\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   3251\u001b[0m     next_tokens \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39margmax(next_token_scores, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Instead of using model.chat(), we call model.generate() directly.\n",
    "# The conversation has to be rendered with tokenizer.apply_chat_template() first.\n",
    "prompt = \"3.14和3.8哪个数字更大？\"\n",
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "    {\"role\": \"user\", \"content\": prompt},\n",
    "]\n",
    "text = tokenizer.apply_chat_template(\n",
    "    messages,\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True,\n",
    ")\n",
    "model_inputs = tokenizer([text], return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "# generate() draws tokens by sampling with this checkpoint's generation settings,\n",
    "# so seed the RNGs first to make the response reproducible across re-runs.\n",
    "set_seed(42)\n",
    "\n",
    "# `max_new_tokens` caps the length of the completion.\n",
    "# The raw result gets its own name so that `generated_ids` is bound only once,\n",
    "# to the completion-only token ids.\n",
    "output_ids = model.generate(\n",
    "    **model_inputs,\n",
    "    max_new_tokens=512,\n",
    ")\n",
    "\n",
    "# Each returned row is prompt + completion; drop the prompt tokens before decoding.\n",
    "prompt_length = model_inputs.input_ids.shape[1]\n",
    "generated_ids = [sequence[prompt_length:] for sequence in output_ids]\n",
    "\n",
    "response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]\n",
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2cf03f3e-2ff2-4071-b607-67953f7b5af7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "要比较3.14和3.8的大小，可以按照以下步骤进行：\n",
      "\n",
      "1. **比较整数部分**：\n",
      "   - 3.14 的整数部分是 3。\n",
      "   - 3.8 的整数部分也是 3。\n",
      "\n",
      "2. **比较小数部分**：\n",
      "   - 3.14 的小数部分是 0.14。\n",
      "   - 3.8 的小数部分是 0.8。\n",
      "\n",
      "3. **比较小数部分的数值**：\n",
      "   - 0.14 和 0.8 比较，显然 0.8 更大。\n",
      "\n",
      "因此，3.8 比 3.14 大。\n"
     ]
    }
   ],
   "source": [
    "# Streaming variant: print the completion to stdout while it is being generated.\n",
    "# Reuses `model_inputs` built in the previous cell (everything before model.generate()).\n",
    "streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
    "\n",
    "# Seed the RNGs so the sampled, streamed text is reproducible across re-runs.\n",
    "set_seed(42)\n",
    "generated_ids = model.generate(\n",
    "    **model_inputs,\n",
    "    max_new_tokens=512,\n",
    "    streamer=streamer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f62a62e3-b5bf-40a3-9d83-523eeac78bd6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
